#!/usr/bin/env python3
from __future__ import annotations
import os
# ---- determinism hygiene: set before importing numpy/torch/scikit-learn ----
# The thread-count variables below are set before the numpy/torch imports that
# follow, so single-threaded BLAS/OpenMP reductions (and thus bit-stable sums)
# are in effect when those libraries initialise. setdefault() means a value
# already present in the caller's environment takes precedence.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
# NOTE: hash randomisation for *this* interpreter is fixed at startup, so this
# only affects child processes spawned later; export PYTHONHASHSEED before
# launching the script if the current process must be hash-deterministic.
os.environ.setdefault("PYTHONHASHSEED", "0")

import argparse, json
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
from tqdm.auto import tqdm

# Category tag names. Position matters: integer categories from the decoded
# JSON ("best_categories") are used directly as row/column indices into the
# 4x4 transition/displacement arrays built below, and this list is saved next
# to them in the output .npz so an index can be mapped back to a name.
# NOTE(review): presumably the decoder uses this same index order — confirm.
FOUR_TAGS = [
    "final_answer",
    "setup_and_retrieval",
    "analysis_and_computation",
    "uncertainty_and_verification",
]

def load_pt_records(pt_path: str):
    """Return the ``"records"`` entry of a ``torch.save``-d dict, loaded onto CPU.

    NOTE(security): ``torch.load`` may unpickle arbitrary objects (depending on the
    installed torch version's ``weights_only`` default); only load trusted files.
    """
    return torch.load(pt_path, map_location="cpu")["records"]

def to_f32_np(x):
    """Convert ``x`` into a float32 numpy array.

    A torch tensor is detached, cast to float32 on CPU and exposed as numpy. A
    non-empty list/tuple whose first element is a tensor is stacked along a new
    leading axis. If that torch path raises for any reason, or ``x`` is anything
    else, the value goes through ``np.asarray`` and is cast to float32 only when
    its dtype differs (so a float32 ndarray is returned as-is, without a copy).
    """
    def _cpu_f32(tensor):
        return tensor.detach().to(dtype=torch.float32, device="cpu")

    try:
        if isinstance(x, torch.Tensor):
            return _cpu_f32(x).contiguous().numpy()
        looks_like_tensor_seq = (
            isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], torch.Tensor)
        )
        if looks_like_tensor_seq:
            stacked = torch.stack([_cpu_f32(item) for item in x], dim=0)
            return stacked.contiguous().numpy()
    except Exception:
        # Deliberate best-effort: fall back to the generic numpy conversion below.
        pass

    out = np.asarray(x)
    return out if out.dtype == np.float32 else out.astype(np.float32, copy=False)

def load_preproc(model_npz_path: str):
    """Rebuild the fitted StandardScaler and optional PCA stored in a model ``.npz``.

    Required keys: ``prep_mean``, ``prep_scale``. A PCA is reconstructed when
    ``prep_pca_components`` is present and non-empty; it then also needs
    ``prep_pca_mean`` and optionally uses ``prep_pca_explained_variance``,
    ``prep_pca_explained_variance_ratio`` and ``prep_pca_singular_values``. Each
    missing optional statistic is replaced independently by a neutral placeholder
    (ones / uniform ratio / ones). The PCA is created with sklearn's default
    ``whiten=False``.

    Returns:
        (scaler, pca): a ``StandardScaler`` with fitted attributes set, and a
        ``PCA`` or ``None`` when no projection is stored.

    Raises:
        KeyError: if ``prep_mean``/``prep_scale`` is missing, or PCA components are
            present without ``prep_pca_mean``.
    """
    from sklearn.preprocessing import StandardScaler

    # NOTE(security): allow_pickle=True lets object arrays in the file unpickle
    # arbitrary objects; only load model files from trusted sources.
    with np.load(model_npz_path, allow_pickle=True) as z:
        # NpzFile keeps the zip archive open; read every member we need while it
        # is open so the handle is closed deterministically on exit.
        present = set(z.files)

        def _get_optional(name):
            if name not in present:
                return None
            arr = z[name]
            # np.savez(..., key=None) stores a 0-d object array holding None;
            # treat that the same as an absent key.
            if arr.dtype == object and arr.ndim == 0 and arr.item() is None:
                return None
            return arr

        prep_mean = z["prep_mean"]
        prep_scale = z["prep_scale"]
        comps = _get_optional("prep_pca_components")
        pca_mean = _get_optional("prep_pca_mean")
        expl_var = _get_optional("prep_pca_explained_variance")
        expl_ratio = _get_optional("prep_pca_explained_variance_ratio")
        sing_vals = _get_optional("prep_pca_singular_values")

    scaler = StandardScaler()
    scaler.mean_ = prep_mean
    scaler.scale_ = prep_scale
    scaler.var_ = scaler.scale_ ** 2
    scaler.n_features_in_ = scaler.mean_.shape[0]

    if comps is None or comps.size == 0:
        return scaler, None
    if pca_mean is None:
        raise KeyError("prep_pca_mean")

    from sklearn.decomposition import PCA
    k = int(comps.shape[0])
    pca = PCA(n_components=k, svd_solver="full")
    pca.components_ = comps
    pca.mean_ = pca_mean
    pca.n_features_in_ = int(pca_mean.shape[0])
    # Filled per field, so a file that stores only some statistics still loads.
    pca.explained_variance_ = expl_var if expl_var is not None else np.ones(k)
    pca.explained_variance_ratio_ = expl_ratio if expl_ratio is not None else np.ones(k) / k
    pca.singular_values_ = sing_vals if sing_vals is not None else np.ones(k)
    return scaler, pca

def apply_preproc_step(H: np.ndarray, scaler, pca) -> np.ndarray:
    """Standardise one step's hidden states and, if configured, PCA-project them.

    ``H``'s last axis must equal ``scaler.n_features_in_``; after scaling it must
    equal ``pca.n_features_in_`` when ``pca`` is given. The result is float64.

    Raises:
        ValueError: on either feature-dimension mismatch.
    """
    got_dim = H.shape[-1]
    want_dim = scaler.n_features_in_
    if got_dim != want_dim:
        raise ValueError(f"[preproc] Hidden state dim {got_dim} != scaler.n_features_in_ {want_dim}")

    scaled = scaler.transform(H)
    if pca is None:
        return scaled.astype(np.float64)

    if scaled.shape[-1] != pca.n_features_in_:
        raise ValueError(f"[preproc] Scaled dim {scaled.shape[-1]} != pca.n_features_in_ {pca.n_features_in_}")
    return pca.transform(scaled).astype(np.float64)

def regime_transition_layer(regs: List[int]) -> int:
    """Return the first index whose regime differs from ``regs[0]``.

    Empty input gives 0; if every entry equals the first, the midpoint
    ``len(regs) // 2`` is returned instead.
    """
    if not regs:
        return 0
    first_regime = int(regs[0])
    changed_at = (
        idx for idx, regime in enumerate(regs[1:], 1) if int(regime) != first_regime
    )
    return next(changed_at, len(regs) // 2)

def unit_var_normalize(v: np.ndarray, eps=1e-8):
    """Center ``v`` to zero (NaN-ignoring) mean and scale it to unit std.

    Returns all zeros when the std is non-finite or below ``eps``; otherwise
    divides by ``std + eps``. NaN entries stay NaN.
    """
    centered = v - np.nanmean(v)
    spread = np.nanstd(centered)
    degenerate = (not np.isfinite(spread)) or spread < eps
    return np.zeros_like(centered) if degenerate else centered / (spread + eps)

def pick_layer_index(regs_per_step: List[int], Z: np.ndarray,
                     use_median_layer: bool = False,
                     use_last_layer: bool = False) -> int:
    """Choose which layer (row of ``Z``) to read a step's vector from.

    Priority: ``use_last_layer`` -> last row; ``use_median_layer`` -> middle row;
    otherwise the regime-transition layer, clamped to the last valid row.
    """
    n_layers = Z.shape[0]
    last = n_layers - 1
    if use_last_layer:
        return last
    if use_median_layer:
        return n_layers // 2
    transition = regime_transition_layer(regs_per_step)
    return transition if transition < last else last

# ---------- alignment helpers ----------
_ID_KEYS = ("sample_id", "i", "id")

def _rec_sort_key(r):
    return (
        r.get("sample_id", None),
        r.get("i", None),
        r.get("id", None),
        (r.get("prompt", "") or "")[:32],
    )

def _key_from(d):
    return tuple(d.get(k, None) for k in _ID_KEYS)

def _decoded_has_ids(decoded_item: dict) -> bool:
    return any(k in decoded_item for k in _ID_KEYS)

def align_records_and_decoded(recs: List[dict], decoded: List[dict]) -> Tuple[List[dict], List[dict], int]:
    if not decoded:
        return [], [], len(recs)
    # Prefer join-by-ID when IDs exist
    if _decoded_has_ids(decoded[0]):
        idx = {_key_from(s): s for s in decoded}
        aligned_r, aligned_s, missing = [], [], 0
        for r in recs:
            key = _key_from(r)
            s = idx.get(key)
            if s is not None:
                aligned_r.append(r); aligned_s.append(s)
            else:
                missing += 1
        return aligned_r, aligned_s, missing
    # Fallback: sort like decode then align by order
    recs_sorted = sorted(recs, key=_rec_sort_key)
    n = min(len(recs_sorted), len(decoded))
    return recs_sorted[:n], decoded[:n], max(0, len(recs_sorted) - n)

# ---------- Edge stats ----------
def build_edge_stats(records, decoded_sequences, scaler, pca,
                     use_median_layer=False, use_last_layer=False, show_progress=False):
    """Accumulate category-transition counts and per-edge hidden-state displacements.

    For each aligned (record, decoded) pair with at least two usable steps, every
    consecutive pair of steps t -> t+1 with categories a -> b adds one transition to
    ``M_correct`` or ``M_incorrect`` (by the record's ``is_correct``), and one
    displacement ``dv = Z_{t+1}[l'] - Z_t[l]`` in preprocessed space. Here ``l`` is
    chosen from step t by ``pick_layer_index``, and ``l'`` is ``l`` clamped to step
    t+1's layer count. Category values are cast with ``int()`` and used directly as
    indices into 4x4 arrays.

    Returns a dict with:
        M_correct, M_incorrect: int64 [4, 4] transition counts.
        CorrMean_edge, IncMean_edge: float64 [4, 4, D*] mean displacement per edge
            (NaN where the edge has no samples), or None if no displacement was seen.
        IncMean_src: float64 [4, D*] mean incorrect displacement per source category.
        Delta: CorrMean_edge minus the mean incorrect displacement that leaves the
            same source via any *other* edge. If there is none, the source's overall
            incorrect mean is used; if the source has no incorrect displacements,
            zero is used. None if no displacement was seen.
        FOUR_TAGS: object array of tag names.
    """
    C = 4
    M_corr = np.zeros((C, C), dtype=np.int64)
    M_inc = np.zeros((C, C), dtype=np.int64)

    sum_corr = None          # [C, C, D*]
    cnt_corr = np.zeros((C, C), dtype=np.int64)
    sum_inc_edge = None      # [C, C, D*]
    cnt_inc_edge = np.zeros((C, C), dtype=np.int64)
    sum_inc_source = None    # [C, D*]
    cnt_inc_source = np.zeros((C,), dtype=np.int64)

    it = zip(records, decoded_sequences)
    if show_progress:
        total = min(len(records), len(decoded_sequences))
        it = tqdm(it, total=total, desc="Edge stats (disp + srcneg)", unit="seq")

    for r, s in it:
        ok = bool(r.get("is_correct", False))
        cats = s.get("best_categories", [])
        regs_per = s.get("best_regimes_per_step", [])
        hs_list = r.get("step_hidden_states", [])
        T = min(len(cats), len(hs_list))
        if T < 2:
            continue

        M = M_corr if ok else M_inc
        for a, b in zip(cats[:T-1], cats[1:T]):
            M[int(a), int(b)] += 1

        # Roll the preprocessed step forward: step t+1's Z is reused as the next
        # iteration's Z_t, so each step is converted/transformed exactly once.
        Z_next = apply_preproc_step(to_f32_np(hs_list[0]), scaler, pca)
        for t in range(T-1):
            a, b = int(cats[t]), int(cats[t+1])

            Z_t = Z_next
            l_t = pick_layer_index(regs_per[t] if t < len(regs_per) else [],
                                   Z_t,
                                   use_median_layer=use_median_layer,
                                   use_last_layer=use_last_layer)

            Z_next = apply_preproc_step(to_f32_np(hs_list[t+1]), scaler, pca)
            # Reuse step t's layer; clamp in case step t+1 exposes fewer layers.
            l_t1 = min(l_t, Z_next.shape[0] - 1)

            dv = Z_next[l_t1] - Z_t[l_t]

            if sum_corr is None:
                Dstar = Z_t.shape[1]
                sum_corr = np.zeros((C, C, Dstar), np.float64)
                sum_inc_edge = np.zeros((C, C, Dstar), np.float64)
                sum_inc_source = np.zeros((C, Dstar), np.float64)

            if ok:
                sum_corr[a, b] += dv
                cnt_corr[a, b] += 1
            else:
                sum_inc_edge[a, b] += dv
                cnt_inc_edge[a, b] += 1
                sum_inc_source[a] += dv
                cnt_inc_source[a] += 1

    def _safe_mean(S, N):
        """S / N broadcast over the trailing feature axis; 0/0 yields NaN."""
        if S is None:
            return None
        with np.errstate(divide="ignore", invalid="ignore"):
            return S / N[..., None]

    CorrMean_edge = _safe_mean(sum_corr, cnt_corr)
    IncMean_edge = _safe_mean(sum_inc_edge, cnt_inc_edge)
    IncMean_src = _safe_mean(sum_inc_source, cnt_inc_source)

    Delta = None
    if sum_corr is not None:
        Ddim = sum_corr.shape[-1]
        IncMean_srcneg = np.full((C, C, Ddim), np.nan, dtype=np.float64)
        for a in range(C):
            S_a, N_a = sum_inc_source[a], cnt_inc_source[a]
            for b in range(C):
                N_ex = N_a - cnt_inc_edge[a, b]
                if N_ex > 0:
                    # Mean of incorrect displacements leaving a via edges other than a->b.
                    IncMean_srcneg[a, b] = (S_a - sum_inc_edge[a, b]) / float(N_ex)
                elif N_a > 0:
                    # Every incorrect displacement from a used a->b: use the source mean.
                    IncMean_srcneg[a, b] = S_a / float(N_a)
                else:
                    # No incorrect displacement from a at all: zero negative.
                    IncMean_srcneg[a, b] = 0.0
        Delta = CorrMean_edge - IncMean_srcneg

    return {
        "M_correct": M_corr,
        "M_incorrect": M_inc,
        "Delta": Delta,
        "FOUR_TAGS": np.array(FOUR_TAGS, dtype=object),
        "CorrMean_edge": CorrMean_edge,
        "IncMean_edge": IncMean_edge,
        "IncMean_src": IncMean_src,
    }

# ---------- Baseline (edge-agnostic) ----------
def build_baseline_all_steps(records, decoded_sequences, scaler, pca,
                             use_median_layer=False, use_last_layer=False,
                             show_progress=False):
    """Edge-agnostic baseline: mean step vector for correct vs incorrect sequences.

    Every step t < min(len(step_hidden_states), len(best_categories)) of every pair
    contributes the preprocessed vector at the layer chosen by ``pick_layer_index``.

    Returns a dict with ``V_corr_global``/``V_inc_global`` (float64 means, or None
    when that side has no steps), ``Delta_global`` (their difference; a missing
    side counts as zeros; None if both are missing) and the step counts
    ``cnt_corr``/``cnt_inc``.
    """
    # Keyed by the sequence's correctness flag; sums allocated once D* is known.
    totals = {True: None, False: None}
    counts = {True: 0, False: 0}

    pairs = zip(records, decoded_sequences)
    if show_progress:
        pairs = tqdm(pairs, total=min(len(records), len(decoded_sequences)),
                     desc="Baseline (all-steps)", unit="seq")

    for rec, seq in pairs:
        is_ok = bool(rec.get("is_correct", False))
        regimes = seq.get("best_regimes_per_step", [])
        hidden = rec.get("step_hidden_states", [])
        n_steps = min(len(hidden), len(seq.get("best_categories", [])))
        for step in range(n_steps):
            Z = apply_preproc_step(to_f32_np(hidden[step]), scaler, pca)
            layer = pick_layer_index(regimes[step] if step < len(regimes) else [],
                                     Z,
                                     use_median_layer=use_median_layer,
                                     use_last_layer=use_last_layer)
            vec = Z[layer]
            if totals[True] is None:
                width = Z.shape[1]
                totals[True] = np.zeros((width,), np.float64)
                totals[False] = np.zeros((width,), np.float64)
            totals[is_ok] += vec
            counts[is_ok] += 1

    def _average(total, n):
        if total is None or n == 0:
            return None
        return np.asarray(total / float(n), dtype=np.float64)

    V_corr_global = _average(totals[True], counts[True])
    V_inc_global = _average(totals[False], counts[False])

    if V_corr_global is None and V_inc_global is None:
        Delta_global = None
    else:
        pos = V_corr_global if V_corr_global is not None else np.zeros_like(V_inc_global)
        neg = V_inc_global if V_inc_global is not None else np.zeros_like(V_corr_global)
        Delta_global = pos - neg

    return {
        "V_corr_global": V_corr_global,
        "V_inc_global": V_inc_global,
        "Delta_global": Delta_global,
        "cnt_corr": counts[True],
        "cnt_inc": counts[False],
    }

def _parse_edges(specs: List[str], n_tags: int) -> List[Tuple[int, int]]:
    """Parse repeated ``--edge`` values of the form ``"s,t"`` into ordered index pairs.

    Order is preserved so the output vector keys are deterministic.

    Raises:
        ValueError: if a spec is not exactly two comma-separated integers, or an
            index lies outside ``[0, n_tags)``. Negative indices would otherwise
            silently wrap around when indexing ``Delta``.
    """
    edges: List[Tuple[int, int]] = []
    for spec in specs:
        parts = spec.split(",")
        if len(parts) != 2:
            raise ValueError(f"--edge expects 's,t', got {spec!r}")
        s, t = int(parts[0]), int(parts[1])
        if not (0 <= s < n_tags and 0 <= t < n_tags):
            raise ValueError(f"--edge {spec!r}: indices must be in [0, {n_tags - 1}]")
        edges.append((s, t))
    return edges


def _load_soft_edges(path: str):
    """Best-effort load of ``{"edges": [[i, j], ...], "weights": [w, ...]}``.

    Negative weights are clipped to 0 and the weights renormalised to sum to 1
    (uniform if the sum is not positive).

    Returns:
        (edges int64 [K, 2], weights float64 [K]), or (None, None) when the file is
        malformed or unreadable. Problems are reported on stdout and never raised.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            blob = json.load(f)
        edges_list = blob.get("edges", [])
        weights = blob.get("weights", [])
        if not (len(edges_list) == len(weights) and len(edges_list) > 0):
            print("[build] soft_edges_file malformed; skipping.", flush=True)
            return None, None
        edges_arr = np.asarray(edges_list, dtype=np.int64).reshape(-1, 2)
        weights_arr = np.clip(np.asarray(weights, dtype=np.float64).reshape(-1), 0.0, None)
        if edges_arr.shape[0] != weights_arr.size:
            # e.g. edges that are not pairs: reshape(-1, 2) would silently re-chunk them.
            print("[build] soft_edges_file malformed; skipping.", flush=True)
            return None, None
        total = weights_arr.sum()
        weights_arr = (weights_arr / total) if total > 0 else np.ones_like(weights_arr) / weights_arr.size
        print(f"[build] loaded soft-edges K={edges_arr.shape[0]} from {path}", flush=True)
        return edges_arr, weights_arr
    except Exception as e:
        # Deliberately best-effort: soft edges are optional extras in the output.
        print(f"[build] failed to load soft_edges_file: {e}", flush=True)
        return None, None


def main():
    """CLI entry point: build edge (and optional baseline) steering vectors into an .npz."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--pt", required=True)
    ap.add_argument("--decoded_json", required=True)
    ap.add_argument("--model_npz", required=True)
    ap.add_argument("--out_npz", required=True)
    ap.add_argument("--edge", action="append", required=True,
                    help="Edge pair s,t (0-3) e.g. --edge 1,3; can repeat.")
    ap.add_argument("--use_median_layer", action="store_true")
    ap.add_argument("--normalize", action="store_true",
                    help="Apply unit variance normalization to steering vectors")
    ap.add_argument("--progress", action="store_true")
    ap.add_argument("--use_last_layer", action="store_true",
                    help="Always pick the last layer for steering vector")
    ap.add_argument("--baseline_all_steps", action="store_true",
                    help="Also compute and save a global baseline vector using ALL steps (edge-agnostic)")
    ap.add_argument("--soft_edges_file", type=str, default=None,
                    help="Path to JSON: {\"edges\": [[i,j],...], \"weights\": [w1,...]}")

    args = ap.parse_args()

    # parse + validate edges (order preserved → deterministic)
    try:
        edges = _parse_edges(args.edge, len(FOUR_TAGS))
    except ValueError as e:
        ap.error(str(e))

    print("[build] loading preproc...", flush=True)
    scaler, pca = load_preproc(args.model_npz)

    print("[build] loading PT + decoded...", flush=True)
    recs_all = load_pt_records(args.pt)
    with open(args.decoded_json, "r", encoding="utf-8") as f:
        decoded_blob = json.load(f)
    decoded_all = decoded_blob["sequences"]

    subset = decoded_blob.get("subset", "all")
    if subset not in {"all", "correct", "incorrect"}:
        print(f"[warn] decoded subset tag unusual: {subset!r}", flush=True)
    # Keep only the PT records matching the subset the decoder was run on.
    recs = recs_all if subset == "all" else [
        r for r in recs_all if bool(r.get("is_correct", False)) == (subset == "correct")
    ]

    print(f"[build] subset={subset}, n_records_before_align={len(recs)}, n_decoded={len(decoded_all)}", flush=True)

    # Robust alignment by IDs (fallback to order if IDs missing)
    recs_aligned, decoded_aligned, missing = align_records_and_decoded(recs, decoded_all)
    if missing:
        print(f"[warn] {missing} PT records had no decoded match; skipped.", flush=True)
    if len(recs_aligned) == 0:
        raise RuntimeError("[build] No aligned pairs after join; check decoded JSON vs PT & subset.")

    print(f"[build] aligned_pairs={len(recs_aligned)}", flush=True)
    r0, s0 = recs_aligned[0], decoded_aligned[0]
    print(f"[build] first_pair: steps_in_PT={len(r0.get('step_hidden_states', []))}  "
          f"best_categories_len={len(s0.get('best_categories', []))}", flush=True)

    print("[build] computing edge stats (displacement + source-conditioned negative)...", flush=True)
    stats = build_edge_stats(recs_aligned, decoded_aligned, scaler, pca,
                             use_median_layer=args.use_median_layer,
                             use_last_layer=args.use_last_layer,
                             show_progress=args.progress)

    vectors = {}
    Delta = stats["Delta"]

    def safe_edge(i, j):
        # No displacement data at all -> 1-element zero placeholder.
        if Delta is None:
            return np.zeros(1, dtype=np.float64)
        v = Delta[i, j]
        # Edge never observed among correct sequences -> all-NaN row -> zeros.
        if np.isnan(v).all():
            return np.zeros(Delta.shape[-1], np.float64)
        return np.nan_to_num(v, nan=0.0)

    for (s, t) in edges:
        v = safe_edge(s, t)
        if args.normalize:
            v = unit_var_normalize(v)
        vectors[f"vec::edge_delta:{s},{t}"] = v

    if args.baseline_all_steps:
        print("[build] computing baseline (all-steps)...", flush=True)
        base = build_baseline_all_steps(recs_aligned, decoded_aligned, scaler, pca,
                                        use_median_layer=args.use_median_layer,
                                        use_last_layer=args.use_last_layer,
                                        show_progress=args.progress)
        baseline_vec = np.zeros(1, dtype=np.float64)
        if base.get("Delta_global") is not None:
            baseline_vec = np.nan_to_num(base["Delta_global"], nan=0.0)
            if args.normalize:
                baseline_vec = unit_var_normalize(baseline_vec)
        vectors["vec::baseline_all_steps"] = baseline_vec

    soft_edges_arr, soft_weights_arr = None, None
    if args.soft_edges_file is not None:
        if os.path.isfile(args.soft_edges_file):
            soft_edges_arr, soft_weights_arr = _load_soft_edges(args.soft_edges_file)
        else:
            print(f"[warn] soft_edges_file not found: {args.soft_edges_file!r}; skipping.", flush=True)

    payload = {
        "M_correct": stats["M_correct"],
        "M_incorrect": stats["M_incorrect"],
        "Delta": stats["Delta"],
        "FOUR_TAGS": stats["FOUR_TAGS"],
        **vectors,
    }
    if soft_edges_arr is not None:
        payload["soft_edges"] = soft_edges_arr
        payload["soft_weights"] = soft_weights_arr

    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.out_npz)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez(args.out_npz, **payload)
    print(f"[build] wrote {args.out_npz} (keys: {sorted(payload.keys())})", flush=True)


if __name__ == "__main__":
    main()
